2b04c7
@@ -24,19 +24,23 @@
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.IntWritable;
 
 public abstract class GenericUDFLeadLag extends GenericUDF
 {
 	transient ExprNodeEvaluator exprEvaluator;
 	transient PTFPartitionIterator<Object> pItr;
 	ObjectInspector firstArgOI;
-
-	private PrimitiveObjectInspector amtOI;
+	ObjectInspector defaultArgOI; // OI of the optional third (default value) argument
+	Converter defaultValueConverter; // converts the default value to the first argument's type
+	int amt; // row offset from the constant second argument; 1 when omitted
 
 	static{
 		PTFUtils.makeTransient(GenericUDFLeadLag.class, "exprEvaluator");
@@ -46,27 +50,30 @@
 	@Override
 	public Object evaluate(DeferredObject[] arguments) throws HiveException
 	{
-		DeferredObject amt = arguments[1];
-		int intAmt = 0;
-		try
-		{
-			intAmt = PrimitiveObjectInspectorUtils.getInt(amt.get(), amtOI);
-		}
-		catch (NullPointerException e)
-		{
-			intAmt = Integer.MAX_VALUE;
-		}
-		catch (NumberFormatException e)
-		{
-			intAmt = Integer.MAX_VALUE;
-		}
+    Object defaultVal = null;
+    if(arguments.length == 3){
+      defaultVal =  ObjectInspectorUtils.copyToStandardObject(
+          defaultValueConverter.convert(arguments[2].get()),
+          defaultArgOI);
+    }
 
 		int idx = pItr.getIndex() - 1;
+		int start = 0;
+		int end = pItr.getPartition().size();
 		try
 		{
-			Object row = getRow(intAmt);
-			Object ret = exprEvaluator.evaluate(row);
-			ret = ObjectInspectorUtils.copyToStandardObject(ret, firstArgOI, ObjectInspectorCopyOption.WRITABLE);
+		  Object ret = null;
+		  int newIdx = getIndex(amt);
+
+		  if(newIdx >= end || newIdx < start) {
+        ret = defaultVal;
+		  }
+		  else {
+        Object row = getRow(amt);
+        ret = exprEvaluator.evaluate(row);
+        ret = ObjectInspectorUtils.copyToStandardObject(ret,
+            firstArgOI, ObjectInspectorCopyOption.WRITABLE);
+		  }
 			return ret;
 		}
 		finally
@@ -83,25 +90,41 @@
public Object evaluate(DeferredObject[] arguments) throws HiveException
 	public ObjectInspector initialize(ObjectInspector[] arguments)
 			throws UDFArgumentException
 	{
-		// index has to be a primitive
-		if (arguments[1] instanceof PrimitiveObjectInspector)
-		{
-			amtOI = (PrimitiveObjectInspector) arguments[1];
-		}
-		else
-		{
-			throw new UDFArgumentTypeException(1,
-					"Primitive Type is expected but "
-							+ arguments[1].getTypeName() + "\" is found");
-		}
-
-		firstArgOI = arguments[0];
-		return ObjectInspectorUtils.getStandardObjectInspector(firstArgOI,
-				ObjectInspectorCopyOption.WRITABLE);
+    if (!(arguments.length >= 1 && arguments.length <= 3)) {
+      throw new UDFArgumentTypeException(arguments.length - 1,
+          "Incorrect invocation of " + _getFnName() + ": _FUNC_(expr, amt, default)");
+    }
+
+    amt = 1; // offset used when only expr is supplied
+    // When supplied, the offset (argument 1) must be a constant INT.
+    if (arguments.length > 1) {
+      ObjectInspector amtOI = arguments[1];
+      if ( !ObjectInspectorUtils.isConstantObjectInspector(amtOI) ||
+          (amtOI.getCategory() != ObjectInspector.Category.PRIMITIVE) ||
+          ((PrimitiveObjectInspector)amtOI).getPrimitiveCategory() !=
+          PrimitiveObjectInspector.PrimitiveCategory.INT )
+      {
+        throw new UDFArgumentTypeException(1,
+            _getFnName() + " amount must be a integer value "
+            + amtOI.getTypeName() + " was passed as parameter 1.");
+      }
+      Object o = ((ConstantObjectInspector)amtOI).
+          getWritableConstantValue();
+      amt = ((IntWritable)o).get();
+    }
+
+    if (arguments.length == 3) {
+      // Convert the default value to the first argument's type so evaluate()
+      // can return it in place of a row that falls outside the partition.
+      defaultArgOI = arguments[2];
+      defaultValueConverter = ObjectInspectorConverters.getConverter(arguments[2], arguments[0]);
+    }
+
+    firstArgOI = arguments[0];
+    return ObjectInspectorUtils.getStandardObjectInspector(firstArgOI,
+        ObjectInspectorCopyOption.WRITABLE);
 	}
 
-
-
 	public ExprNodeEvaluator getExprEvaluator()
 	{
 		return exprEvaluator;
@@ -122,7 +145,39 @@
public void setpItr(PTFPartitionIterator<Object> pItr)
 		this.pItr = pItr;
 	}
 
-	@Override
+  /** Returns the object inspector of the expression (first) argument. */
+  public ObjectInspector getFirstArgOI() { return firstArgOI; }
+
+  /** Sets the object inspector of the expression (first) argument. */
+  public void setFirstArgOI(ObjectInspector firstArgOI) {
+    this.firstArgOI = firstArgOI;
+  }
+
+  /** Returns the object inspector of the optional default-value (third) argument. */
+  public ObjectInspector getDefaultArgOI() { return defaultArgOI; }
+
+  /** Sets the object inspector of the optional default-value (third) argument. */
+  public void setDefaultArgOI(ObjectInspector defaultArgOI) {
+    this.defaultArgOI = defaultArgOI;
+  }
+
+  /** Returns the converter from the default value's type to the first argument's type. */
+  public Converter getDefaultValueConverter() { return defaultValueConverter; }
+
+  /** Sets the converter from the default value's type to the first argument's type. */
+  public void setDefaultValueConverter(Converter defaultValueConverter) {
+    this.defaultValueConverter = defaultValueConverter;
+  }
+
+  /** Returns the row offset (1 when the function is called without one). */
+  public int getAmt() { return amt; }
+
+  /** Sets the row offset. */
+  public void setAmt(int amt) {
+    this.amt = amt;
+  }
+
+  @Override
 	public String getDisplayString(String[] children)
 	{
 		assert (children.length == 2);
@@ -140,6 +195,8 @@
public String getDisplayString(String[] children)
 
 	protected abstract Object getRow(int amt);
 
+	protected abstract int getIndex(int amt); // target row index; evaluate() bounds-checks it against the partition
+
 	public static class GenericUDFLead extends GenericUDFLeadLag
 	{
 
@@ -149,6 +206,11 @@
protected String _getFnName()
 			return "lead";
 		}
 
+		@Override
+		protected int getIndex(int amt) { // lead: amt rows after pItr.getIndex() - 1
+		  return amt + (pItr.getIndex() - 1);
+		}
+
 		@Override
 		protected Object getRow(int amt)
 		{
@@ -165,6 +227,11 @@
protected String _getFnName()
 			return "lag";
 		}
 
+		@Override
+		protected int getIndex(int amt) { // lag: amt rows before pItr.getIndex() - 1
+		  return (pItr.getIndex() - 1) - amt;
+		}
+
 		@Override
 		protected Object getRow(int amt)
 		{
